package server

import (
	"io"
	"net"
	"time"

	"github.com/pingcap/tidb/mysql"

	"github.com/zeast/logs"
)

// packetIO reads and writes MySQL protocol packets on a net.Conn. Every packet
// on the wire is a 4-byte header (a 3-byte little-endian payload length and a
// 1-byte sequence number) followed by the payload bytes.
type packetIO struct {
	// conn is the underlying network connection.
	conn net.Conn
	// readTimeout, when positive, is armed as a read deadline before each
	// packet read.
	readTimeout time.Duration
	// writeTimeout is the configured write timeout.
	writeTimeout time.Duration
	// sequence is the sequence number expected on the next packet read and
	// stamped on the next packet written; it advances once per packet.
	sequence uint8
}

// newPacketIO wraps conn for packet-level I/O. Both timeouts and the sequence
// number start at their zero values.
func newPacketIO(conn net.Conn) *packetIO {
	return &packetIO{conn: conn}
}

// setReadTimeout stores the read timeout, expressed in whole seconds.
// readPacket only arms a deadline when the stored value is positive, so zero
// or a negative t leaves reads without a deadline.
func (p *packetIO) setReadTimeout(t int) {
	seconds := time.Duration(t)
	p.readTimeout = seconds * time.Second
}

// setWriteTimeout stores the write timeout, expressed in whole seconds.
func (p *packetIO) setWriteTimeout(t int) {
	p.writeTimeout = time.Second * time.Duration(t)
}

// readPacket reads one complete MySQL protocol payload from the connection.
//
// Each packet starts with a 4-byte header: a 3-byte little-endian payload
// length followed by a 1-byte sequence number, which must equal p.sequence
// and is advanced after every packet. A payload of MaxPacketSize bytes or more
// arrives as a run of MaxPacketSize-byte packets followed by one shorter
// packet. That final packet is empty when the payload length is an exact
// multiple of MaxPacketSize. The fragments are concatenated and returned as a
// single slice.
//
// When readTimeout is positive, a read deadline is re-armed before every
// packet so a broken connection cannot block the read forever.
//
// Any read or framing failure is logged and reported as mysql.ErrBadConn.
func (p *packetIO) readPacket() ([]byte, error) {
	var payload []byte
	for {
		if p.readTimeout > 0 {
			// Best effort: if the deadline cannot be set the connection is
			// already unusable and the read below will report it.
			_ = p.conn.SetReadDeadline(time.Now().Add(p.readTimeout))
		}

		var header [4]byte
		if _, err := io.ReadFull(p.conn, header[:]); err != nil {
			// io.ReadFull returns io.EOF only when no byte was read, which
			// presumably means the peer closed the connection cleanly, so it
			// is logged at debug level only.
			if err == io.EOF {
				logs.Debugf("读取协议 header 错误:%v, remote:%s", err, p.conn.RemoteAddr())
			} else {
				logs.Errorf("读取协议 header 错误:%v, remote:%s", err, p.conn.RemoteAddr())
			}
			return nil, mysql.ErrBadConn
		}

		length := int(uint32(header[0]) | uint32(header[1])<<8 | uint32(header[2])<<16)
		// An empty packet is only legal as the terminator of a split payload,
		// i.e. after at least one MaxPacketSize-byte fragment has been read.
		if length < 1 && payload == nil {
			logs.Errorf("header 数据错误: invalid payload length %d, remote:%s", length, p.conn.RemoteAddr())
			return nil, mysql.ErrBadConn
		}

		sequence := header[3]
		if sequence != p.sequence {
			logs.Errorf("包序列号错误: invalid sequence %d != %d, remote:%s", sequence, p.sequence, p.conn.RemoteAddr())
			return nil, mysql.ErrBadConn
		}
		p.sequence++

		data := make([]byte, length)
		if _, err := io.ReadFull(p.conn, data); err != nil {
			logs.Errorf("读取协议 payload 错误:%v, remote:%s", err, p.conn.RemoteAddr())
			return nil, mysql.ErrBadConn
		}

		if payload == nil {
			payload = data
		} else {
			payload = append(payload, data...)
		}

		// A packet shorter than MaxPacketSize (including an empty one) ends
		// the payload.
		if length < MaxPacketSize {
			return payload, nil
		}
	}
}

// writePacket sends one MySQL protocol payload on the connection.
//
// data must reserve its first 4 bytes for the packet header; the payload is
// data[4:]. The header bytes (and, for split payloads, bytes already sent) are
// overwritten in place. A payload of MaxPacketSize bytes or more is sent as
// MaxPacketSize-byte packets followed by a final shorter, possibly empty,
// packet. Each packet is stamped with p.sequence, which advances after every
// successful packet write.
//
// When writeTimeout is positive, a write deadline is armed before each packet
// write so a stalled peer cannot block the caller forever. This mirrors what
// readPacket does with readTimeout.
//
// Any write failure is reported as mysql.ErrBadConn.
func (p *packetIO) writePacket(data []byte) error {
	// send writes one fully framed packet and advances the sequence number.
	send := func(pkt []byte) error {
		if p.writeTimeout > 0 {
			// Best effort: a failure here means the connection is unusable
			// and the Write below will report it.
			_ = p.conn.SetWriteDeadline(time.Now().Add(p.writeTimeout))
		}
		n, err := p.conn.Write(pkt)
		if err != nil {
			logs.Errorf("发送网络数据的时候发生了错误:%v, remote:%s", err, p.conn.RemoteAddr())
			return mysql.ErrBadConn
		}
		if n != len(pkt) {
			return mysql.ErrBadConn
		}
		p.sequence++
		return nil
	}

	length := len(data) - 4

	for length >= MaxPacketSize {
		data[0] = 0xff
		data[1] = 0xff
		data[2] = 0xff
		data[3] = p.sequence

		if err := send(data[:4+MaxPacketSize]); err != nil {
			return err
		}
		length -= MaxPacketSize
		// Advance so the next header occupies the last 4 bytes of the chunk
		// just sent; those bytes are already on the wire and can be reused as
		// header scratch space without copying the remaining payload.
		data = data[MaxPacketSize:]
	}

	data[0] = byte(length)
	data[1] = byte(length >> 8)
	data[2] = byte(length >> 16)
	data[3] = p.sequence

	return send(data)
}

// writeMultiPacket frames every element of datas as a MySQL protocol payload
// and sends all resulting packets with a single Write.
//
// Unlike writePacket, the elements of datas are bare payloads with no header
// space reserved; the 4-byte headers are generated here. A payload of
// MaxPacketSize bytes or more is split into MaxPacketSize-byte packets plus a
// final shorter, possibly empty, packet. Every packet consumes one sequence
// number.
//
// When writeTimeout is positive, a write deadline is armed before the write
// so a stalled peer cannot block the caller forever.
//
// Any write failure is reported as mysql.ErrBadConn.
func (p *packetIO) writeMultiPacket(c *Client, datas [][]byte) error {
	// NOTE(review): assumes c.buf.Get() yields an empty buffer to append the
	// frames to; it is not returned to c.buf here — confirm that contract.
	total := c.buf.Get().Bytes()
	for _, data := range datas {
		for len(data) >= MaxPacketSize {
			total = append(total, 0xff, 0xff, 0xff, p.sequence)
			total = append(total, data[:MaxPacketSize]...)
			p.sequence++
			data = data[MaxPacketSize:]
		}

		length := len(data)
		total = append(total, byte(length), byte(length>>8), byte(length>>16), p.sequence)
		total = append(total, data...)
		p.sequence++
	}

	if p.writeTimeout > 0 {
		// Best effort: a failure here means the connection is unusable and
		// the Write below will report it.
		_ = p.conn.SetWriteDeadline(time.Now().Add(p.writeTimeout))
	}
	n, err := p.conn.Write(total)
	if err != nil {
		logs.Errorf("发送网络数据的时候发生了错误:%v, remote:%s", err, p.conn.RemoteAddr())
		return mysql.ErrBadConn
	}
	if n != len(total) {
		return mysql.ErrBadConn
	}
	return nil
}
